import os
import cv2
import time
import numpy as np
import threading
import matplotlib.pyplot as plt

from skimage import io
from sklearn.cluster import KMeans

# Directory holding the input images, resolved against the current working directory.
imageDir = os.path.abspath("images")
# os.listdir returns entries in arbitrary order; sorting makes filenames[0]
# (the image main() processes) the same on every platform.
# NOTE(review): entries are not filtered, so non-image files may be listed too.
filenames = sorted(os.listdir(imageDir))
# Cache directory name (not used by the code visible here).
cacheDir = "cache"

class KMeans0(object):
	"""Small NumPy k-means used to pick ``k`` representative colours of an image.

	Attributes:
		colors: float32 array of shape (k, channels) with the cluster centres
			found by the last ``fit`` call (``None`` before fitting).
		epcho: number of iterations the last ``fit`` call performed.
		k: number of clusters.
		iter: maximum number of iterations per ``fit`` call.
	"""

	# Pixels assigned per vectorised step. This bounds the temporary
	# (chunk, k, channels) difference array so large images stay cheap in memory.
	_CHUNK = 4096

	def __init__(self, k = 256, it = 8):
		self.colors = None
		# Iteration counter (attribute name kept as-is for existing callers).
		self.epcho = 0
		self.k = k
		self.iter = it

	def distance(self, colors, color):
		"""Return the squared Euclidean distance from ``color`` to every row of ``colors``."""
		return np.sum(np.power(colors - color, 2), axis=1)

	def _assign(self, pixels):
		"""Return, for every row of ``pixels``, the index of the nearest centre."""
		labels = np.empty(len(pixels), dtype=np.intp)
		for start in range(0, len(pixels), self._CHUNK):
			chunk = pixels[start:start + self._CHUNK]
			# (n, 1, C) - (1, k, C) broadcasts to (n, k, C); summing over C gives (n, k).
			diff = chunk[:, np.newaxis, :] - self.colors[np.newaxis, :, :]
			labels[start:start + len(chunk)] = np.argmin(np.sum(diff * diff, axis=2), axis=1)
		return labels

	def fit(self, img):
		"""Cluster the pixel colours of ``img`` into ``self.k`` centres.

		Args:
			img: array of shape (rows, cols, channels), or (rows, cols) for a
				single-channel image.

		Raises:
			ValueError: if ``self.k`` < 1 or ``img`` contains no pixels.
		"""
		if self.k < 1:
			raise ValueError("k must be at least 1, got {0}".format(self.k))
		# Flatten to one row per pixel; float32 keeps the arithmetic free of uint8 wrap-around.
		img = np.asarray(img)
		channels = img.shape[2] if img.ndim == 3 else 1
		pixels = img.reshape(-1, channels).astype(np.float32)
		count = len(pixels)
		if count == 0:
			raise ValueError("cannot fit on an image with no pixels")
		# Seed the centres with random pixels. Sample without replacement whenever
		# there are enough pixels, so two centres never start on the same pixel.
		self.colors = pixels[np.random.choice(count, size=self.k, replace=count < self.k)]
		self.epcho = 0
		while True:
			labels = self._assign(pixels)
			# Per-cluster pixel counts and per-channel colour sums.
			counts = np.bincount(labels, minlength=self.k)
			sums = np.stack(
				[np.bincount(labels, weights=pixels[:, c], minlength=self.k) for c in range(channels)],
				axis=1,
			)
			oldColors = self.colors.astype(np.uint8)
			# The new centre is the true mean of its pixels. An empty cluster keeps
			# its previous centre instead of dividing by zero.
			newColors = self.colors.copy()
			nonEmpty = counts > 0
			newColors[nonEmpty] = sums[nonEmpty] / counts[nonEmpty][:, np.newaxis]
			self.colors = newColors
			self.epcho += 1
			# Converged when no centre moved at uint8 precision (works for any channel count).
			if np.array_equal(oldColors, self.colors.astype(np.uint8)):
				break
			# Stop after at most self.iter iterations.
			if self.epcho >= self.iter:
				break

	def extractColor(self):
		"""Return the fitted centres as a uint8 array (fractional parts truncated)."""
		return np.array(self.colors, dtype=np.uint8)

def KmeansColor(imgPath, k = 16, crop = 16):
	"""Cluster the top-left corner of an image with ``KMeans0`` and print the result.

	Args:
		imgPath: path of the image file; its bytes are read with ``np.fromfile``
			and decoded with ``cv2.imdecode``.
		k: number of colour clusters.
		crop: side length of the top-left square of pixels that is clustered.

	Raises:
		ValueError: if the file cannot be decoded as an image.
	"""
	startTime = time.perf_counter()
	# Flag -1 (cv2.IMREAD_UNCHANGED) keeps the file's own channel count.
	img = cv2.imdecode(np.fromfile(imgPath, dtype=np.uint8), -1)
	if img is None:
		raise ValueError("could not decode image: {0}".format(imgPath))
	# Slice only the two spatial axes so single-channel (2-D) images work too.
	img = img[:crop, :crop]
	kmeans = KMeans0(k)
	kmeans.fit(img)
	print("time: {0}\n color: {1}".format(time.perf_counter() - startTime, kmeans.extractColor()))

def skKmeansColor(imgPath, k = 16, crop = 16):
	"""Cluster the top-left corner of an image with scikit-learn's KMeans and print the result.

	Args:
		imgPath: path of the image file. It was previously ignored in favour of
			``filenames[0]``.
		k: number of colour clusters.
		crop: side length of the top-left square of pixels that is clustered.

	Raises:
		ValueError: if the file cannot be decoded as an image.
	"""
	startTime = time.perf_counter()
	# Flag -1 (cv2.IMREAD_UNCHANGED) keeps the file's own channel count.
	img = cv2.imdecode(np.fromfile(imgPath, dtype=np.uint8), -1)
	if img is None:
		raise ValueError("could not decode image: {0}".format(imgPath))
	img = img[:crop, :crop]
	# One row per pixel; a 2-D (single-channel) crop becomes a single column.
	pixels = img.reshape(-1, img.shape[2] if img.ndim == 3 else 1)
	estimator = KMeans(n_clusters=k, max_iter=8, init='k-means++')
	estimator.fit(pixels)
	print("time: {0}\n color: {1}".format(time.perf_counter() - startTime, np.array(estimator.cluster_centers_, dtype=np.uint8)))

def main():
	"""Run the hand-written and the scikit-learn k-means on the first listed image."""
	firstImage = os.path.join(imageDir, filenames[0])
	for runner in (KmeansColor, skKmeansColor):
		runner(firstImage)

if __name__ == "__main__":
	main()
	# SystemExit is what sys.exit raises; the bare exit() builtin comes from the
	# site module, is meant for the interactive shell, and may be absent.
	raise SystemExit(0)